{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一.参数的梯度下降求解\n",
    "\n",
    "#### 损失函数推导\n",
    "CRF的参数学习同HMM一样，采用的极大似然估计的方式，其对数似然函数为：   \n",
    "\n",
    "$$\n",
    "L(w)=L_{\\tilde{P}}(P_w)\\\\\n",
    "=log\\prod_{x,y}P_w(y\\mid x)^{\\tilde{P}(x,y)}\\\\\n",
    "=\\sum_{x,y}\\tilde{P}(x,y)logP_w(y\\mid x)\\\\\n",
    "=\\sum_{j=1}^NlogP_w(y_j\\mid x_j)（假设样本量为N）\n",
    "$$  \n",
    "\n",
    "将CRF的势函数：   \n",
    "\n",
    "$$\n",
    "P_w(y\\mid x)=\\frac{exp(\\sum_{k=1}^Kw_kf_k(y,x))}{Z_w(x)}\n",
    "$$  \n",
    "\n",
    "带入上面的对数似然函数可得：   \n",
    "\n",
    "$$\n",
    "L(w)=\\sum_{j=1}^N\\sum_{k=1}^Kw_kf_k(y_j,x_j)-\\sum_{j=1}^NlogZ_w(x_j)\n",
    "$$  \n",
    "\n",
    "最终，我们的问题等价于：   \n",
    "\n",
    "$$\n",
    "w_k^*=arg\\min_{w_k}\\sum_{j=1}^N(logZ_w(x_j)-\\sum_{k=1}^Kw_kf_k(y_j,x_j))\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 梯度下降\n",
    "\n",
    "最简单的就是梯度下降，可推导出其梯度函数为：     \n",
    "\n",
    "$$\n",
    "g(w^t)=\\sum_{j=1}^N(P_{w^t}(y_j\\mid x_j)-1)F(y_j,x_j))\n",
    "$$  \n",
    "\n",
    "这里，$F(y_j,x_j)=[f_1(y_j,x_j),f_2(y_j,x_j),...,f_K(y_j,x_j)]^T$，所以梯度更新公式也就出来了：   \n",
    "\n",
    "$$\n",
    "w^{t+1}=w^t-\\eta g(w^t)\n",
    "$$  \n",
    "\n",
    "这里，$\\eta$为学习率，用梯度更新公式来理解一下训练过程，每次迭代会将特征函数输出为1的权重增加，而输出为0的保持不变，则看起来也蛮合理的~~~"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二.代码实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')\n",
    "from ml_models.pgm import CRFFeatureFunction\n",
    "import numpy as np\n",
    "\n",
    "\"\"\"\n",
    "线性链CRF的实现，封装到ml_models.pgm\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class CRF(object):\n",
    "    def __init__(self, epochs=10, lr=1e-3, tol=1e-5, output_status_num=None, input_status_num=None, unigram_rulers=None,\n",
    "                 bigram_rulers=None):\n",
    "        \"\"\"\n",
    "        :param epochs: 迭代次数\n",
    "        :param lr: 学习率\n",
    "        :param tol:梯度更新的阈值\n",
    "        :param output_status_num:标签状态数\n",
    "        :param input_status_num:输入状态数\n",
    "        :param unigram_rulers: 状态特征规则\n",
    "        :param bigram_rulers: 状态转移规则\n",
    "        \"\"\"\n",
    "        self.epochs = epochs\n",
    "        self.lr = lr\n",
    "        self.tol = tol\n",
    "        # 为输入序列和标签状态序列添加一个头尾id\n",
    "        self.output_status_num = output_status_num + 2\n",
    "        self.input_status_num = input_status_num + 2\n",
    "        self.input_status_head_tail = [input_status_num, input_status_num + 1]\n",
    "        self.output_status_head_tail = [output_status_num, output_status_num + 1]\n",
    "        # 特征函数\n",
    "        self.FF = CRFFeatureFunction(unigram_rulers, bigram_rulers)\n",
    "        # 模型参数\n",
    "        self.w = None\n",
    "\n",
    "    def fit(self, x, y):\n",
    "        \"\"\"\n",
    "        :param x: [[...],[...],...,[...]]\n",
    "        :param y: [[...],[...],...,[...]]\n",
    "        :return\n",
    "        \"\"\"\n",
    "        # 为 x,y加头尾\n",
    "        x = [[self.input_status_head_tail[0]] + xi + [self.input_status_head_tail[1]] for xi in x]\n",
    "        y = [[self.output_status_head_tail[0]] + yi + [self.output_status_head_tail[1]] for yi in y]\n",
    "        self.FF.fit(x, y)\n",
    "        self.w = np.ones(len(self.FF.feature_funcs)) * 1e-5\n",
    "        for _ in range(0, self.epochs):\n",
    "            # 偷个懒，用随机梯度下降\n",
    "            for i in range(0, len(x)):\n",
    "                xi = x[i]\n",
    "                yi = y[i]\n",
    "                \"\"\"\n",
    "                1.求F(yi \\mid xi)以及P_w(yi \\mid xi)\n",
    "                \"\"\"\n",
    "                F_y_x = []\n",
    "                Z_x = np.ones(shape=(self.output_status_num, 1)).T\n",
    "                for j in range(1, len(xi)):\n",
    "                    F_y_x.append(self.FF.map(yi[j - 1], yi[j], xi, j))\n",
    "                    # 构建M矩阵\n",
    "                    M = np.zeros(shape=(self.output_status_num, self.output_status_num))\n",
    "                    for k in range(0, self.output_status_num):\n",
    "                        for t in range(0, self.output_status_num):\n",
    "                            M[k, t] = np.exp(np.dot(self.w, self.FF.map(k, t, xi, j)))\n",
    "                    # 前向算法求 Z(x)\n",
    "                    Z_x = Z_x.dot(M)\n",
    "                F_y_x = np.sum(F_y_x, axis=0)\n",
    "                Z_x = np.sum(Z_x)\n",
    "                # 求P_w(yi \\mid xi)\n",
    "                P_w = np.exp(np.dot(self.w, F_y_x)) / Z_x\n",
    "                \"\"\"\n",
    "                2.求梯度,并更新\n",
    "                \"\"\"\n",
    "                dw = (P_w - 1) * F_y_x\n",
    "                self.w = self.w - self.lr * dw\n",
    "                if (np.sqrt(np.dot(dw, dw) / len(dw))) < self.tol:\n",
    "                    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 随便测试一下\n",
    "x = [\n",
    "    [1, 2, 3, 0, 1, 3, 4],\n",
    "    [1, 2, 3],\n",
    "    [0, 2, 4, 2],\n",
    "    [4, 3, 2, 1],\n",
    "    [3, 1, 1, 1, 1],\n",
    "    [2, 1, 3, 2, 1, 3, 4]\n",
    "]\n",
    "y = x\n",
    "\n",
    "crf = CRF(output_status_num=5, input_status_num=5)\n",
    "crf.fit(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.00000000e-05, 1.00009053e-01, 7.00089277e-02, 7.00091106e-02,\n",
       "       2.00098984e-02, 4.00097839e-02, 6.00090103e-02, 1.00000000e-05,\n",
       "       2.00092452e-02, 2.00092452e-02, 1.00099996e-02, 1.00099996e-02,\n",
       "       3.00099986e-02, 2.00099991e-02, 2.00099991e-02, 2.00092452e-02,\n",
       "       2.00092452e-02, 2.00092452e-02, 1.00099996e-02, 1.00099996e-02,\n",
       "       3.00099986e-02, 2.00099991e-02, 2.00099991e-02, 1.00000000e-05,\n",
       "       2.00092452e-02, 2.00092452e-02, 1.00099996e-02, 1.00099996e-02,\n",
       "       3.00099986e-02, 2.00099991e-02, 2.00099991e-02, 2.00092452e-02,\n",
       "       2.00092452e-02, 2.00092452e-02, 1.00099996e-02, 1.00099996e-02,\n",
       "       3.00099986e-02, 2.00099991e-02, 2.00099991e-02, 2.00092452e-02,\n",
       "       2.00092452e-02, 2.00092452e-02, 1.00099996e-02, 1.00099996e-02,\n",
       "       3.00099986e-02, 2.00099991e-02, 2.00099991e-02, 1.00092456e-02,\n",
       "       1.00092456e-02, 1.00092456e-02, 1.00092456e-02, 1.00092456e-02,\n",
       "       1.00000000e-05, 1.00098988e-02, 1.00098988e-02, 1.00098988e-02,\n",
       "       1.00098988e-02, 1.00098988e-02, 1.00098988e-02, 1.00098988e-02,\n",
       "       1.00098988e-02, 1.00098988e-02, 1.00000000e-05, 1.00098988e-02,\n",
       "       1.00098988e-02, 1.00098988e-02, 1.00098988e-02, 1.00098988e-02,\n",
       "       1.00098988e-02, 1.00098988e-02, 1.00098988e-02, 1.00098988e-02,\n",
       "       1.00098988e-02, 1.00098988e-02, 1.00098988e-02, 1.00098988e-02,\n",
       "       1.00098988e-02, 1.00000000e-05, 1.00098860e-02, 2.00098855e-02,\n",
       "       3.00098850e-02, 2.00098668e-02, 1.00098860e-02, 1.00098860e-02,\n",
       "       2.00098855e-02, 3.00098850e-02, 2.00098668e-02, 1.00000000e-05,\n",
       "       1.00098860e-02, 2.00098855e-02, 3.00098850e-02, 2.00098668e-02,\n",
       "       1.00098860e-02, 1.00098860e-02, 2.00098855e-02, 3.00098850e-02,\n",
       "       2.00098668e-02, 1.00098860e-02, 1.00098860e-02, 2.00098855e-02,\n",
       "       3.00098850e-02, 2.00098668e-02, 1.00000000e-05, 1.00099808e-02,\n",
       "       3.00099424e-02, 1.00099808e-02, 1.00099808e-02, 3.00099424e-02,\n",
       "       1.00000000e-05, 1.00099808e-02, 3.00099424e-02, 1.00099808e-02,\n",
       "       1.00099808e-02, 3.00099424e-02, 1.00099808e-02, 1.00099808e-02,\n",
       "       3.00099424e-02, 1.00000000e-05, 1.00099995e-02, 1.00000000e-05,\n",
       "       1.00099995e-02, 1.00099995e-02])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "crf.w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "122"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(crf.w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可知一共有122个特征函数，接下来还剩最后一个问题，那就是标签预测的问题，见下一小节~~~"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
